import av
import numpy
import tellopy
import cv2
import os
import json
import math
import time
import sys
import traceback

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms

import torch.nn as nn
import torch.utils.data
import numpy as np
from opt import opt

from dataloader import ImageLoader, DetectionLoader, DetectionProcessor, DataWriter, Mscoco
from yolo.util import write_results, dynamic_write_results
from SPPE.src.main_fast_inference import *

import os
import sys
from tqdm import tqdm
import time
from fn import getTime

from pPose_nms import pose_nms, write_json

'''
Consume the live video stream sent back by a Tello drone and run pose
estimation on each frame.
'''

# img_path = 'duan_alphapose\\photo\\'
# posePoint_Path = 'duan_alphapose\\'

# Configure AlphaPose runtime options (shared `opt` namespace from opt.py).
# NOTE(review): absolute, user-specific paths — must be adjusted per machine.
args = opt
args.inputpath = '/Users/yunyi/Desktop/AlphaPose/duan_alphapose/photo/'
args.outputpath = '/Users/yunyi/Desktop/AlphaPose/duan_alphapose/'
args.sp = True  # single-process mode (no torch.multiprocessing workers)
args.dataset = 'coco'

img_path = args.inputpath  # directory where captured frames are written/read

if not args.sp:
    # Multi-process mode: use fork-server workers and file-system tensor
    # sharing (avoids "too many open files" with shared CUDA tensors).
    torch.multiprocessing.set_start_method('forkserver', force=True)
    torch.multiprocessing.set_sharing_strategy('file_system')

# Collect the input image filenames for the AlphaPose pipeline.
def GetImage(args):
    """Return the filenames of the images directly under ``args.inputpath``.

    Also ensures ``args.outputpath`` exists (created if missing).

    Parameters
    ----------
    args : namespace with ``inputpath`` and ``outputpath`` attributes.

    Returns
    -------
    list of str — plain filenames (no directory prefix); ``[]`` when the
    input directory is missing or empty.

    BUG FIX: the original iterated ``os.walk`` and kept only the file list
    of the *last* visited directory, and raised UnboundLocalError when the
    directory did not exist. We list only the top-level files, which is what
    the caller (the per-frame capture loop) actually produces.
    """
    # exist_ok also tolerates the directory already being present,
    # matching the original `if not os.path.exists` guard.
    os.makedirs(args.outputpath, exist_ok=True)

    if not os.path.isdir(args.inputpath):
        return []
    return [name for name in os.listdir(args.inputpath)
            if os.path.isfile(os.path.join(args.inputpath, name))]

# Load the pose-estimation network.
def downModel():
    """Build the SPPE pose model, move it to the available device, and
    return it in eval mode.

    Uses the fast-inference variant when ``args.fast_inference`` is set.
    """
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Loading YOLO model...')
    dataset = Mscoco()
    # pick the network variant once, then construct it
    net_cls = InferenNet_fast if args.fast_inference else InferenNet
    model = net_cls(4 * 1 + 1, dataset)
    model.to(dev)
    model.eval()
    return model

# Run the AlphaPose pipeline on a batch of images and extract keypoints.
def Alphapose(im_names, pose_model,):
    """Run detection + pose estimation over *im_names* and return the
    keypoints of the first detected person in the first image.

    Parameters
    ----------
    im_names : list of image filenames under ``args.inputpath``.
    pose_model : the SPPE network returned by ``downModel()``.

    Returns
    -------
    The ``'keypoints'`` tensor of the first person, or ``None`` when
    nothing was detected (or no image was processed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load input images via the threaded loader pipeline
    data_loader = ImageLoader(im_names, batchSize=args.detbatch, format='yolo').start()

    # Human-detection stage feeding the pose stage
    sys.stdout.flush()
    det_loader = DetectionLoader(data_loader, batchSize=args.detbatch).start()
    det_processor = DetectionProcessor(det_loader).start()
    # per-stage timing accumulators: detection / pose / post-processing
    runtime_profile = {
        'dt': [],
        'pt': [],
        'pn': []
    }

    # Init data writer (renders results and collects the final output)
    writer = DataWriter(args.save_video).start()

    data_len = data_loader.length()
    im_names_desc = tqdm(range(data_len))

    batchSize = args.posebatch
    for i in im_names_desc:
        start_time = getTime()
        with torch.no_grad():
            (inps, orig_img, im_name, boxes, scores, pt1, pt2) = det_processor.read()
            if boxes is None or boxes.nelement() == 0:
                # no person detected in this image — record an empty result
                writer.save(None, None, None, None, None, orig_img, im_name.split('/')[-1])
                continue
            ckpt_time, det_time = getTime(start_time)
            runtime_profile['dt'].append(det_time)

            # Pose estimation, in chunks of at most args.posebatch crops
            datalen = inps.size(0)
            leftover = 1 if datalen % batchSize else 0
            num_batches = datalen // batchSize + leftover
            hm = []
            for j in range(num_batches):
                inps_j = inps[j * batchSize:min((j + 1) * batchSize, datalen)].to(device)
                hm.append(pose_model(inps_j))
            hm = torch.cat(hm)
            ckpt_time, pose_time = getTime(ckpt_time)
            runtime_profile['pt'].append(pose_time)
            hm = hm.cpu()
            writer.save(boxes, scores, hm, pt1, pt2, orig_img, im_name.split('/')[-1])

            ckpt_time, post_time = getTime(ckpt_time)
            runtime_profile['pn'].append(post_time)

        if args.profile:
            # TQDM per-stage timing readout
            im_names_desc.set_description(
            'det time: {dt:.3f} | pose time: {pt:.2f} | post processing: {pn:.4f}'.format(
                dt=np.mean(runtime_profile['dt']), pt=np.mean(runtime_profile['pt']), pn=np.mean(runtime_profile['pn']))
            )

    print('Finish Model Running.')
    if (args.save_img or args.save_video) and not args.vis_fast:
        print('===========================> Rendering remaining images in the queue...')
        print('===========================> If this step takes too long, you can enable the --vis_fast flag to use fast rendering (real-time).')
    # Wait for the writer thread to drain its queue.
    # BUG FIX: the original spun with `while(writer.running()): pass`,
    # pinning a CPU core; sleep briefly between polls instead.
    while writer.running():
        time.sleep(0.01)
    writer.stop()
    final_result = writer.results()
    # write_json(final_result, args.outputpath)
    # Guard against an empty result list as well (the original raised
    # IndexError on `final_result[0]` when no image was processed).
    if final_result and final_result[0]['result']:
        # keypoints of the first detected person in the first image
        return final_result[0]['result'][0]['keypoints']
    else:
        return None

# Classify the arm pose from the keypoints and draw the skeleton overlay.
def PoseFind(point_results):
    """Classify the arm pose from 17 COCO keypoints and visualize them.

    ``point_results`` is indexable as ``point_results[i][0/1].item()``
    (e.g. a 17x2+ tensor) in the COCO keypoint order:

    0 Nose, 1 LEye, 2 REye, 3 LEar, 4 REar, 5 LShoulder, 6 RShoulder,
    7 LElbow, 8 RElbow, 9 LWrist, 10 RWrist, 11 LHip, 12 RHip,
    13 LKnee, 14 RKnee, 15 LAnkle, 16 RAnkle

    Returns an int pose code 0-10 (0 = no pose matched):
    1 both arms down, 2 both arms level, 3 both arms raised,
    4 both arms extended, 5 left extended / right level,
    6 left extended / right down, 7 left level / right extended,
    8 left down / right extended, 9 left raised / right down,
    10 right raised / left down.

    NOTE(review): the thresholds compare y coordinates, which appear to
    grow downward (image coordinates) — confirm against the detector.
    Also reads the frame image from the module-level ``img_path``.
    """
    # Extract all 17 keypoints as integer (x, y) pixel tuples.
    # BUG FIX: the original spelled these out in 17 near-identical lines
    # and shadowed the builtin `len` with the reference length.
    pts = [(int(point_results[i][0].item()), int(point_results[i][1].item()))
           for i in range(17)]

    # y coordinates of left/right shoulder, elbow, wrist
    LS_Y = pts[5][1]
    RS_Y = pts[6][1]
    LE_Y = pts[7][1]
    RE_Y = pts[8][1]
    LH_Y = pts[9][1]
    RH_Y = pts[10][1]

    # Reference length: twice the inter-eye distance (scale-invariant).
    ref_len = int(math.hypot(pts[1][0] - pts[2][0], pts[1][1] - pts[2][1])) * 2

    # Classify the pose. The chain is deliberately kept as independent
    # `if`s (not elif) to preserve the original last-match-wins behavior
    # when thresholds tie (e.g. ref_len == 0).
    pose = 0
    if LH_Y - LE_Y >= ref_len and LE_Y - LS_Y >= ref_len and RH_Y - RE_Y >= ref_len and RE_Y - RS_Y >= ref_len:
        pose = 1  # both arms hanging down
        print('---------------------------双垂--------------------------')
    if LE_Y - LH_Y >= ref_len and abs(LE_Y - LS_Y) <= ref_len and RE_Y - RH_Y >= ref_len and abs(RE_Y - RS_Y) <= ref_len:
        pose = 2  # both arms level
        print('---------------------------双平举--------------------------')
    if LS_Y - LE_Y >= ref_len and LE_Y - LH_Y >= ref_len and RS_Y - RE_Y >= ref_len and RE_Y - RH_Y >= ref_len:
        pose = 3  # both arms raised high
        print('---------------------------双高举--------------------------')
    if abs(LH_Y - LE_Y) <= ref_len and abs(LE_Y - LS_Y) <= ref_len and abs(RH_Y - RE_Y) <= ref_len and abs(RE_Y - RS_Y) <= ref_len:
        pose = 4  # both arms extended straight out
        print('---------------------------双伸--------------------------')
    if abs(LH_Y - LE_Y) <= ref_len and abs(LE_Y - LS_Y) <= ref_len and RE_Y - RH_Y >= ref_len and abs(RE_Y - RS_Y) <= ref_len:
        pose = 5  # left extended, right level
        print('---------------------------左伸右平举--------------------------')
    if abs(LH_Y - LE_Y) <= ref_len and abs(LE_Y - LS_Y) <= ref_len and RH_Y - RE_Y >= ref_len and RE_Y - RS_Y >= ref_len:
        pose = 6  # left extended, right down
        print('---------------------------左伸右垂--------------------------')
    if LE_Y - LH_Y >= ref_len and abs(LE_Y - LS_Y) <= ref_len and abs(RH_Y - RE_Y) <= ref_len and abs(RE_Y - RS_Y) <= ref_len:
        pose = 7  # left level, right extended
        print('---------------------------左平举右伸--------------------------')
    if LH_Y - LE_Y >= ref_len and LE_Y - LS_Y >= ref_len and abs(RH_Y - RE_Y) <= ref_len and abs(RE_Y - RS_Y) <= ref_len:
        pose = 8  # left down, right extended
        print('---------------------------左垂右伸--------------------------')
    if LS_Y - LE_Y >= ref_len and LE_Y - LH_Y >= ref_len and RH_Y - RE_Y >= ref_len and RE_Y - RS_Y >= ref_len:
        pose = 9  # left raised, right down
        print('---------------------------左高举右垂--------------------------')
    if LH_Y - LE_Y >= ref_len and LE_Y - LS_Y >= ref_len and RS_Y - RE_Y >= ref_len and RE_Y - RH_Y >= ref_len:
        pose = 10  # right raised, left down
        print('---------------------------右高举左垂--------------------------')

    # Visualize: draw each keypoint on the saved frame.
    img = cv2.imread(img_path + 'frame.jpg')
    for point in pts:
        cv2.circle(img, point, 1, (0, 0, 255), 4)

    # Skeleton edges (same connections as the original 16 cv2.line calls).
    skeleton = [(0, 1), (0, 2), (1, 3), (2, 4),
                (9, 7), (7, 5), (5, 6), (6, 8), (8, 10),
                (5, 11), (6, 12), (11, 12),
                (11, 13), (13, 15), (12, 14), (14, 16)]
    for a, b in skeleton:
        cv2.line(img, pts[a], pts[b], (0, 255, 0), 1, 4)

    cv2.imshow('AlphaPose', img)
    cv2.waitKey(1)
    return pose

# Helper to empty a directory tree of files (used to clear saved frames).
def del_files(path_file):
    """Recursively delete every *file* under ``path_file``.

    Subdirectories are descended into but not removed themselves.
    """
    for entry in os.listdir(path_file):
        full = os.path.join(path_file, entry)
        if os.path.isdir(full):
            # recurse into the subdirectory
            del_files(full)
        else:
            os.remove(full)


def main():
    """Stream video from a Tello drone, run AlphaPose on each frame,
    classify the arm pose, and show the annotated result.

    Each decoded frame is written to ``img_path`` as ``frame.jpg`` so the
    file-based AlphaPose pipeline can pick it up, then the directory is
    cleared again.

    NOTE(review): the outer ``while True`` re-enters ``container.decode``
    after the stream ends — presumably to keep reading a live stream;
    confirm it terminates as intended when the drone disconnects.
    """
    drone = tellopy.Tello()

    try:
        drone.connect()
        drone.wait_for_connection(60.0)

        container = av.open(drone.get_video_stream())
        # skip first 300 frames (lets the decoder settle on the live feed)
        frame_skip = 300

        # load the pose-estimation model once, outside the frame loop
        PoseModel = downModel()
        while True:
            for frame in container.decode(video=0):
                if 0 < frame_skip:
                    frame_skip = frame_skip - 1
                    continue
                start_time = time.time()
                # PyAV frame -> RGB PIL image -> BGR ndarray for OpenCV
                image = cv2.cvtColor(numpy.array(frame.to_image()), cv2.COLOR_RGB2BGR)
                cv2.imshow('Original', image)

                cv2.imwrite(img_path + 'frame.jpg',image)# save the frame to disk
                # AlphaPose consumes images from a directory, so the frame
                # is written to img_path and read back by the pipeline.

                # collect the input image names
                img_names=GetImage(args)

                # run AlphaPose to extract keypoints
                point_results = Alphapose(img_names, PoseModel)

                # classify the pose from the keypoints (None = no person)
                if point_results is not None:
                    pose=PoseFind(point_results)
                    print(pose)
                else:
                    pass

                del_files(img_path)# clear the saved frame(s)

                # Skip however many frames arrived while we were busy
                # processing, so the loop keeps up with the live stream.
                if frame.time_base < 1.0 / 60:
                    time_base = 1.0 / 60
                else:
                    time_base = frame.time_base
                frame_skip = int((time.time() - start_time) / time_base)


    except Exception as ex:
        # top-level boundary: log the full traceback, then clean up
        exc_type, exc_value, exc_traceback = sys.exc_info()
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        print(ex)
    finally:
        drone.quit()
        cv2.destroyAllWindows()

if __name__ == '__main__':
    main()
